# functions for 04-types-marginal-plots ----

# Level-by-level summary of average type probabilities for one covariate.
#
# Args:
#   predicted_probabilities: data.table with columns preferences,
#     electability, expected_utility, a draw index `b`, and the covariate
#     column named by `var`.
#   var: name (string) of the covariate column.
#   pretty_var: display label stored in the `covariate` column.
#
# Returns: data.table with one row per (covariate value, type) and columns
#   var, covariate, type (factor: Preferences, Electability,
#   Expected Utility), est (mean over draws) and q025/q975 (2.5% / 97.5%
#   quantiles over draws).
get_levels <- function(predicted_probabilities, var, pretty_var) {
  type_cols <- c("preferences", "electability", "expected_utility")
  type_labels <- c("Preferences", "Electability", "Expected Utility")

  # average over rows within each (covariate value, draw) cell
  cell_means <- predicted_probabilities[, .(
    preferences = mean(preferences),
    electability = mean(electability),
    expected_utility = mean(expected_utility)),
    .(var = get(var), b)]

  # summarise the draws of each type within each covariate value, then stack
  summaries <- lapply(seq_along(type_cols), function(k) {
    draw_col <- type_cols[k]
    cell_means[, .(
      covariate = pretty_var,
      type = factor(type_labels[k], levels = type_labels),
      est = mean(get(draw_col)),
      q025 = quantile(get(draw_col), .025),
      q975 = quantile(get(draw_col), .975)),
      var]
  })
  do.call(rbind, summaries)
}
# Effect of moving a covariate from `val0` to `val1` on the average
# probability of each type, summarised over simulation draws.
#
# Args:
#   predicted_probabilities: data.table with columns preferences,
#     electability, expected_utility, a draw index `b`, and the covariate
#     column named by `var`.
#   var: name (string) of the covariate column.
#   pretty_var: display label stored in the `covariate` column.
#   val1, val0: covariate values to contrast (val1 minus val0).
#
# Returns: data.table with one row per type and columns covariate, type
#   (factor: Preferences, Electability, Expected Utility), est (mean over
#   draws) and q025/q975 (2.5% / 97.5% quantiles over draws).
get_margins <- function(predicted_probabilities, var, pretty_var,
  val1, val0) {
  type_cols <- c("preferences", "electability", "expected_utility")
  type_labels <- c("Preferences", "Electability", "Expected Utility")

  # average over rows within each (covariate value, draw) cell
  cell_means <- predicted_probabilities[, .(
    preferences = mean(preferences),
    electability = mean(electability),
    expected_utility = mean(expected_utility)),
    .(var = get(var), b)]

  # within each draw, difference between the val1 and val0 cells
  draw_effects <- cell_means[, .(
    preferences = preferences[var == val1] - preferences[var == val0],
    electability = electability[var == val1] - electability[var == val0],
    expected_utility = expected_utility[var == val1] - expected_utility[var == val0]
  ), b]

  summaries <- lapply(seq_along(type_cols), function(k) {
    effect_col <- type_cols[k]
    draw_effects[, .(
      covariate = pretty_var,
      type = factor(type_labels[k], levels = type_labels),
      est = mean(get(effect_col)),
      q025 = quantile(get(effect_col), .025),
      q975 = quantile(get(effect_col), .975))]
  })
  do.call(rbind, summaries)
}

# functions for 03-types.R ----

# Draw a ternary (simplex) plot of respondents' fitted type weights.
#
# Each row of `fitted_thetas` is mapped to 2-D barycentric coordinates on an
# equilateral triangle with vertices Preferences = (0, 0),
# Electability = (1/2, sqrt(3)/2) and Expected Utility = (1, 0). The rows are
# drawn as a filled 2-D density in grayscale (contour_var = "ndensity"), with
# hard-coded gridlines and tick labels at 25/50/75% and a white point at the
# column-wise mean.
#
# Args:
#   fitted_thetas: matrix-like object indexed as fitted_thetas[, k]; column 1
#     is the Preferences weight, column 2 Electability, column 3 Expected
#     Utility. Only columns 2 and 3 enter the coordinates, so rows presumably
#     sum to 1 -- TODO confirm against the caller.
#   gridlines_x_intercept: not used by the live code; only the commented-out
#     gridline construction below refers to it.
#   treat: optional vector aligned with the rows of `fitted_thetas`. When
#     supplied, the mean point is computed from rows with treat == 0 only.
#
# Returns: a ggplot object.
make_ternary_plot_of_types <- function(fitted_thetas,
  gridlines_x_intercept = c(.25, .5, .75),
  treat = NULL) {
  # assumes preferences in first col,
  # electability in second,
  # expected utility in third
  # Map (theta2, theta3) to Cartesian coordinates inside the triangle.
  barycentric <- data.table(
    x = (fitted_thetas[, 2] * .5 + fitted_thetas[, 3]),
    y = (fitted_thetas[, 2] * .5 * sqrt(3))
  )

  # NOTE: median_thetas is computed here but never drawn; only mean_thetas
  # is plotted below.
  if (is.null(treat)) {
    median_theta_values <- apply(fitted_thetas, 2, median)
    median_thetas <- data.table(
      x = median_theta_values[2] * .5 + median_theta_values[3],
      y = median_theta_values[2] * .5 * sqrt(3)
    )
    mean_theta_values <- apply(fitted_thetas, 2, mean)
    mean_thetas <- data.table(
      x = mean_theta_values[2] * .5 + mean_theta_values[3],
      y = mean_theta_values[2] * .5 * sqrt(3)
    )
  }

  # With a treatment indicator, the mean point uses the treat == 0 rows only.
  if (!is.null(treat)) {
    # median_theta_values <- apply(fitted_thetas[treat == 1, ], 2, mean)
    # median_thetas <- data.table(
    #   x = median_theta_values[2] * .5 + median_theta_values[3],
    #   y = median_theta_values[2] * .5 * sqrt(3)
    # )
    mean_theta_values <- apply(fitted_thetas[treat == 0, ], 2, mean)
    mean_thetas <- data.table(
      x = mean_theta_values[2] * .5 + mean_theta_values[3],
      y = mean_theta_values[2] * .5 * sqrt(3)
    )
  }
  # median_thetas <- data.table(
  #   x = median(barycentric$x),
  #   y = median(barycentric$y)
  # )
  # mean_thetas <- data.table(
  #   x = mean(barycentric$x),
  #   y = mean(barycentric$y)
  # )

  # White polygons drawn on top of the density outside the left, right,
  # bottom and top of the triangle, presumably to mask density that spills
  # beyond the simplex. `o` sets how far they extend.
  # NOTE(review): the last x vertex of blot_box_D/U is 0+o (not 0-o), which
  # leaves a small unmasked corner at the lower/upper left -- confirm intent.
  o <- .3 #overlap
  blot_box_L <- data.table(x = c(0-o*.25, 0, .5, 0-o*.25),
    y = c(0-o, 0, sqrt(3) / 2, sqrt(3) / 2))
  blot_box_R <- data.table(x = c(1+o*.25, 1, .5, 1+o*.25),
    y = c(0-o, 0, sqrt(3) / 2, sqrt(3) / 2))
  blot_box_D <- data.table(x = c(0-o, 1+o, 1+o, 0+o),
    y = c(0, 0, -o, -o))
  blot_box_U <- data.table(x = c(0-o, 1+o, 1+o, 0+o),
    y = c(sqrt(3) / 2, sqrt(3) / 2, sqrt(3) / 2+o - .25, sqrt(3) / 2+o - .25))

  # y_intercept <- - sqrt(3) * gridlines_x_intercept
  # pos_grid_lines <- data.table(
  #   x = gridlines_x_intercept,
  #   xend = (sqrt(3) - y_intercept) / (2 * sqrt(3)),
  #   y = rep(0, 3),
  #   yend = y_intercept + (sqrt(3) - y_intercept) / 2
  # )
  # y_intercept <- sqrt(3) * gridlines_x_intercept
  # neg_grid_lines <- data.table(
  #   x = gridlines_x_intercept,
  #   xend = y_intercept / (2 * sqrt(3)),
  #   y = rep(0, 3),
  #   yend = y_intercept / 2
  # )
  # horizontal_grid_lines <- data.table(
  #   x = neg_grid_lines$xend,
  #   xend = rev(pos_grid_lines$xend),
  #   y = neg_grid_lines$yend,
  #   yend = rev(pos_grid_lines$yend)
  # )
  # Line width of the triangle edges; interior gridlines use half of it.
  gsize <- .5
  # Density: 70 geometrically spaced breaks from .025/.975 up to 1 on the
  # ndensity scale, paired with the 70 gray shades (gray100..gray31) in the
  # manual fill/colour scales.
  # NOTE(review): `..level..` is deprecated in ggplot2 >= 3.4 in favour of
  # after_stat(level).
  # Gridlines: the three triangle edges, then lines parallel to each edge at
  # 25/50/75%. Endpoints are hard-coded; 0.2165064, 0.4330127 and 0.6495191
  # are sqrt(3)/2 times 0.25, 0.5 and 0.75.
  g <- ggplot(barycentric, aes(x, y)) +
    geom_density2d_filled(
      aes(fill = ..level.., color = ..level..),
      contour_var = "ndensity",
      breaks = exp(seq(log(.025), log(.975), length.out = 70)) / exp(log(.975)),
      n = 200) +
    annotate(geom = "segment", x = 0, xend = 0.5, y = 0, yend = 0.5 * sqrt(3),
      color = "black", size = gsize) +
    annotate(geom = "segment", x = 1, xend = 0.5, y = 0, yend = 0.5 * sqrt(3),
      color = "black", size = gsize) +
    annotate(geom = "segment", x = 0, xend = 1, y = 0, yend = 0,
      color = "black", size = gsize) +
    annotate(geom = "segment", x = .25, xend = .625, y = 0, yend = 0.6495191,
      color = "black", size = gsize/2) +
    annotate(geom = "segment", x = .5, xend = 0.750, y = 0, yend = 0.4330127,
      color = "black", size = gsize/2) +
    annotate(geom = "segment", x = .75, xend = 0.875, y = 0, yend = 0.2165064,
      color = "black", size = gsize/2) +
    annotate(geom = "segment", x = .25, xend = 0.125, y = 0, yend = 0.2165064,
      color = "black", size = gsize/2) +
    annotate(geom = "segment", x = .5, xend = 0.250, y = 0, yend = 0.4330127,
      color = "black", size = gsize/2) +
    annotate(geom = "segment", x = .75, xend = 0.375, y = 0, yend = 0.6495191,
      color = "black", size = gsize/2) +
    annotate(geom = "segment", x = 0.125, xend = 0.875, y = 0.2165064, yend = 0.2165064,
      color = "black", size = gsize/2) +
    annotate(geom = "segment", x = 0.250, xend = 0.750, y = 0.4330127, yend = 0.4330127,
      color = "black", size = gsize/2) +
    annotate(geom = "segment", x = 0.375, xend = 0.625, y = 0.6495191, yend = 0.6495191,
      color = "black", size = gsize/2) +
    geom_point(data = mean_thetas, color = "white", size = 2) +
    geom_polygon(data = blot_box_L, fill = "white", color = NA) +
    geom_polygon(data = blot_box_R, fill = "white", color = NA) +
    geom_polygon(data = blot_box_D, fill = "white", color = NA) +
    geom_polygon(data = blot_box_U, fill = "white", color = NA) +
    # coord_equal(clip = "on", xlim = c(-.0, 1.0), ylim = c(-.0, sqrt(3)/2+.0),
    #   expand = FALSE) +
    scale_fill_manual(values = paste0("gray", 100:31)) +#brewer.pal(9, "Greys")) +
    scale_color_manual(values = paste0("gray", 100:31)) +#brewer.pal(9, "Greys")) +
    theme_bw() +
    theme(
      axis.title = element_blank(),
      axis.text = element_blank(),
      axis.ticks = element_blank(),
      axis.line = element_blank(),
      panel.grid = element_blank(),
      panel.border = element_blank(),
      legend.position = "none")
  # Offsets from the edges for axis titles (h1), arrows (h2) and tick labels
  # (h3); w1 is the arrow half-length. `r` is not used below.
  r <- .05
  h1 <- .09
  h2 <- h1 * .6
  h3 <- .02
  w1 <- .25
  # Axis titles sit just outside the midpoint of each edge (bottom =
  # Preferences, left = Electability, right = Expected Utility). Each arrow
  # runs parallel to its edge toward that type's vertex. Tick labels give the
  # weight of the edge's type, so on the bottom edge 25% sits at x = .75.
  # clip = "off" lets titles, arrows and labels draw outside the
  # [0, 1] x [0, sqrt(3)/2] panel.
  g  +
    annotate(
      geom = "text",
      x = 1/2 + h1 * (0),
      y = 0 + h1 * (-1),
      angle = 0,
      label = "Preferences") +
    annotate(
      geom = "text",
      x = 1/4 + h1 * (-4/5),
      y = sqrt(3)/4 + h1 * (3/5),
      angle = 60,
      label = "Electability") +
    annotate(
      geom = "text",
      x = 3/4 + h1 * (4/5),
      y = sqrt(3)/4 + h1 * (3/5),
      angle = -60,
      label = "Expected Utility") +
    annotate(geom = "segment",
      # origin + normal vector to axis (for text) + parallel vector to axis (for arrow)
      x = 1/4 + h2 * (-4/5) - w1 * (1/2),
      y = sqrt(3)/4 + h2 * (3/5) - w1 * (sqrt(3)/2),
      xend = 1/4 + h2 * (-4/5) + w1 * (1/2),
      yend = sqrt(3)/4 + h2 * (3/5) + w1 * (sqrt(3)/2),
      size = .25,
      arrow = arrow(type = "closed", angle = 15, length = unit(.125, "inches"))) +
    annotate(geom = "segment",
      x = 1/2 + h2 * (0) - w1 * (-1),
      y = 0 + h2 * (-1) - w1 * (0),
      xend = 1/2 + h2 * (0) + w1 * (-1),
      yend = 0 + h2 * (-1) + w1 * (0),
      size = .25,
      arrow = arrow(type = "closed", angle = 15, length = unit(.125, "inches")))  +
    annotate(geom = "segment",
      x = 3/4 + h2 * (4/5) - w1 * (1/2),
      y = sqrt(3)/4 + h2 * (3/5) - w1 * (-sqrt(3)/2),
      xend = 3/4 + h2 * (4/5) + w1 * (1/2),
      yend = sqrt(3)/4 + h2 * (3/5) + w1 * (-sqrt(3)/2),
      size = .25,
      arrow = arrow(type = "closed", angle = 15, length = unit(.125, "inches"))) +
    annotate(
      geom = "text",
      x = .75 + h3 * (0),
      y = 0 + h3 * (-1),
      angle = 0,
      label = "25%",
      size = 3) +
    annotate(
      geom = "text",
      x = .50 + h3 * (0),
      y = 0 + h3 * (-1),
      angle = 0,
      label = "50%",
      size = 3) +
    annotate(
      geom = "text",
      x = .25 + h3 * (0),
      y = 0 + h3 * (-1),
      angle = 0,
      label = "75%",
      size = 3) +

    annotate(
      geom = "text",
      x = 1/2 * .75 + h3 * (-4/5),
      y = sqrt(3)/2 * .75 + h3 * (3/5),
      angle = 60,
      label = "75%",
      size = 3) +
    annotate(
      geom = "text",
      x = 1/2 * .5 + h3 * (-4/5),
      y = sqrt(3)/2 * .5 + h3 * (3/5),
      angle = 60,
      label = "50%",
      size = 3) +
    annotate(
      geom = "text",
      x = 1/2 * .25 + h3 * (-4/5),
      y = sqrt(3)/2 * .25 + h3 * (3/5),
      angle = 60,
      label = "25%",
      size = 3) +

    annotate(
      geom = "text",
      x = 1/2 + 1/2 * .75 + h3 * (4/5),
      y = sqrt(3)/2 * .25 + h3 * (3/5),
      angle = -60,
      label = "75%",
      size = 3) +
    annotate(
      geom = "text",
      x = 1/2 + 1/2 * .5 + h3 * (4/5),
      y = sqrt(3)/2 * .5 + h3 * (3/5),
      angle = -60,
      label = "50%",
      size = 3) +
    annotate(
      geom = "text",
      x = 1/2 + 1/2 * .25 + h3 * (4/5),
      y = sqrt(3)/2 * .75 + h3 * (3/5),
      angle = -60,
      label = "25%",
      size = 3) +
    coord_equal(clip = "off", xlim = c(-.0, 1.0), ylim = c(-.0, sqrt(3)/2+.0),
      expand = FALSE) +
    theme(plot.margin = unit(c(1,1,1,0), "lines"))
}

# Simulated predicted probabilities from a three-category multinomial logit
# whose reference category is "Preferences".
#
# Coefficient vectors are drawn from N(coef(object), vcov(object)). The first
# half of each draw is the Electability equation, the second half the
# Expected Utility equation. This matches flattening coef(object) row-wise,
# which assumes coef() returns one row per non-reference category -- TODO
# confirm for the model class used.
#
# Args:
#   object: fitted model with coef() and vcov() methods.
#   X: model matrix (n x k) whose columns match each equation's coefficients.
#   nsims: number of coefficient draws (default 1000, as before).
#
# Returns: list of three n x nsims matrices (Preferences, Electability,
#   Expected Utility); column j holds the probabilities under draw j.
calculate_predicted_probabilities <- function(object, X, nsims = 1000) {
  mean_coefs <- as.vector(t(coef(object)))
  variance_covariance_matrix <- vcov(object)
  # mvrnorm() drops to a plain vector when nsims == 1; keep it a matrix
  sims <- matrix(MASS::mvrnorm(nsims, mean_coefs, variance_covariance_matrix),
    nrow = nsims)
  n_per_eq <- ncol(sims) / 2
  sims_electability <- sims[, seq_len(n_per_eq), drop = FALSE]
  sims_expected_utility <- sims[, n_per_eq + seq_len(n_per_eq), drop = FALSE]
  lp_electability <- X %*% t(sims_electability)
  lp_expected_utility <- X %*% t(sims_expected_utility)
  # Softmax with the reference category's linear predictor fixed at 0.
  # Shifting by the element-wise maximum keeps exp() finite; the unshifted
  # form gives Inf / Inf = NaN once a linear predictor exceeds ~709.
  shift <- pmax(lp_electability, lp_expected_utility, 0)
  exp_preferences <- exp(-shift)
  exp_electability <- exp(lp_electability - shift)
  exp_expected_utility <- exp(lp_expected_utility - shift)
  denominators <- exp_preferences + exp_electability + exp_expected_utility
  list(
    exp_preferences / denominators,
    exp_electability / denominators,
    exp_expected_utility / denominators
  )
}

# Average effect of setting `electability_last` to 1 versus 0 on each type's
# predicted probability, with simulation-based 95% intervals.
#
# Both counterfactual designs are evaluated in a single
# calculate_predicted_probabilities() call on the stacked matrix, so draw j
# uses the same coefficient vector for the treated and control rows.
# Previously the two designs were simulated separately. Each per-draw
# difference then mixed two independent coefficient draws, which inflated
# the interval widths.
#
# Args:
#   object: fitted model with a `terms` component plus coef()/vcov().
#   data: data used to build the model matrix; it must produce an
#     `electability_last` column.
#
# Returns: named list; each element is c(mean, `2.5%`, `97.5%`) of the
#   per-draw average difference in probability.
summarize_effects <- function(object, data) {
  X <- model.matrix(as.formula(object$terms), data = data)
  X1 <- X
  X0 <- X
  X1[, "electability_last"] <- 1
  X0[, "electability_last"] <- 0
  n <- nrow(X)
  rows1 <- seq_len(n)
  rows0 <- n + seq_len(n)
  predprobs <- calculate_predicted_probabilities(object, rbind(X1, X0))
  labels <- c(
    "Effect on Preferences Type",
    "Effect on Electability Type",
    "Effect on Exp. Utility Type"
  )
  out <- list()
  for (k in seq_along(labels)) {
    probs <- predprobs[[k]]
    # per-draw mean (over respondents) of the treated-minus-control difference
    x <- colMeans(probs[rows1, , drop = FALSE] - probs[rows0, , drop = FALSE])
    out[[labels[k]]] <- c(mean(x), quantile(x, c(.025, .975)))
  }
  out
}

# Per-respondent type probabilities (averaged over simulation draws) under
# both values of the `electability_last` indicator, stacked into one table.
#
# Returns: data.table with columns Preferences, Electability,
#   Expected Utility and treat. Rows with treat = 1 (electability_last = 1)
#   come first, then rows with treat = 0.
summarize_distribution <- function(object, data) {
  design <- model.matrix(as.formula(object$terms), data = data)
  with_indicator <- function(value) {
    counterfactual <- design
    counterfactual[, "electability_last"] <- value
    counterfactual
  }
  design_treated <- with_indicator(1)
  design_control <- with_indicator(0)
  # treated design is simulated first, then the control design
  sims_treated <- calculate_predicted_probabilities(object, design_treated)
  sims_control <- calculate_predicted_probabilities(object, design_control)
  average_over_draws <- function(prob_list) {
    do.call(cbind, lapply(prob_list, rowMeans))
  }
  treated <- average_over_draws(sims_treated)
  control <- average_over_draws(sims_control)
  stacked <- as.data.table(rbind(treated, control))
  setnames(stacked, c("Preferences", "Electability", "Expected Utility"))
  stacked[, treat := rep(1:0, each = nrow(control))]
  stacked
}


# for clustered standard errors ----
# Cluster-robust standard errors and normal-approximation p-values.
#
# Args:
#   fitted_model: model supported by sandwich::estfun() / sandwich::sandwich()
#     that has a `rank` component and a df.residual() method (e.g. lm, glm).
#   cluster: cluster identifier, one entry per observation used in the fit.
#
# Returns: matrix with columns se and pval, one row per coefficient.
mclx <- function(fitted_model, cluster) {
  M <- length(unique(cluster))
  N <- length(cluster)
  K <- fitted_model$rank
  # small-sample correction M/(M-1) * (N-1)/(N-K)
  dfc <- (M / (M - 1)) * ((N - 1) / (N - K))
  # sum each observation's score contributions within its cluster
  u1 <- apply(sandwich::estfun(fitted_model), 2, function(x) tapply(x, cluster, sum))
  vc1 <- dfc * sandwich::sandwich(fitted_model, meat = crossprod(u1) / N )
  # df.residual() instead of `fitted_model$df`: the latter only worked
  # through partial matching and is NULL for glm objects, where it is
  # ambiguous between `df.residual` and `df.null`.
  resid_df <- stats::df.residual(fitted_model)
  dfcw <- resid_df / (resid_df - (M - 1))
  vcovMCL <- vc1 * dfcw
  se <- diag(vcovMCL) ^ .5
  zval <- coef(fitted_model) / se
  pval <- 2 * pnorm(abs(zval), lower.tail = FALSE)
  cbind(se = se, pval = pval)
}

# based on CBPS:::balance.CBPS ----
# Covariate balance between the two levels of a binary treatment (adapted
# from CBPS:::balance.CBPS).
#
# Args:
#   formula: treatment ~ covariates. Covariates are expanded with
#     model.matrix(), and constant columns (e.g. the intercept) are dropped.
#   data: data frame in which to evaluate `formula`.
#
# Returns: data.table with one row per retained model-matrix column:
#   covariate (column name) and abs_diff_std_means, the absolute difference
#   of the two treatment groups' means divided by the full-sample SD.
#
# Stops unless the treatment has exactly two levels. The original code
# allocated room for a single contrast only, so any other case failed with
# "subscript out of bounds".
balance <- function (formula, data) {
  if (missing(data))
    data <- environment(formula)
  # Rebuild the call as model.frame(formula, data, na.action) and evaluate
  # it in the caller's frame.
  mf <- match.call(expand.dots = FALSE)
  m <- match(c("formula", "data", "na.action"), names(mf), 0L)
  mf <- mf[c(1L, m)]
  mf$drop.unused.levels <- TRUE
  mf[[1L]] <- as.name("model.frame")
  mf <- eval(mf, parent.frame())
  mt <- attr(mf, "terms")
  Y <- model.response(mf, "any")
  if (length(dim(Y)) == 1L) {
    nm <- rownames(Y)
    dim(Y) <- NULL
    if (!is.null(nm))
      names(Y) <- nm
  }
  X <- if (!is.empty.model(mt))
    model.matrix(mt, mf)
  else matrix(NA, NROW(Y), 0L)
  # drop constant columns; drop = FALSE keeps a matrix when one column remains
  X <- X[, apply(X, 2, sd) > 0, drop = FALSE]
  treats <- as.factor(Y)
  treat.names <- levels(treats)
  if (length(treat.names) != 2L) {
    stop("balance() requires a treatment with exactly two levels; found ",
      length(treat.names), ".", call. = FALSE)
  }
  col_sd <- apply(X, 2, sd)
  std_mean <- function(level) {
    apply(X[which(treats == level), , drop = FALSE], 2, mean) / col_sd
  }
  data.table(
    covariate = colnames(X),
    abs_diff_std_means = unname(abs(std_mean(treat.names[1]) -
        std_mean(treat.names[2]))))
}


# older ----
# Standardise each row of a numeric matrix to mean 0 and SD 1.
#
# Rows with zero variance become all zeros. Rows whose SD is NA (a row
# containing NA, or a single-column matrix) now yield NA instead of
# aborting the whole call in `if (s == 0)`. A zero-row input returns a
# zero-row matrix instead of failing on 1:nrow(x) = c(1, 0).
#
# Args:
#   x: numeric matrix.
#
# Returns: matrix with the same number of rows and columns as `x`.
scale_score <- function(x) {
  if (nrow(x) == 0L) {
    return(matrix(numeric(0), nrow = 0L, ncol = ncol(x)))
  }
  do.call(rbind, lapply(seq_len(nrow(x)), function(i) {
    row_i <- x[i, ]
    s <- sd(row_i)
    if (!is.na(s) && s == 0) {
      rep(0, ncol(x))
    } else {
      (row_i - mean(row_i)) / s
    }
  }))
}

# Gather, for each row i, the scores x[i, remaining_ranks[i, r]] for
# r = 1..n_remaining_ranks.
#
# Uses one two-column (row, column) matrix index instead of an R-level
# loop per cell. Zero-row or zero-width inputs return an empty matrix
# instead of failing on 1:nrow(x).
#
# Args:
#   n_remaining_ranks: number of leading columns of `remaining_ranks` to use.
#   x: score matrix (rows = respondents, columns = items).
#   remaining_ranks: matrix of column indices into `x`, one row per row of `x`.
#
# Returns: nrow(x) x n_remaining_ranks matrix.
make_remaining_ranking_score <- function(n_remaining_ranks, x, remaining_ranks) {
  n_rows <- nrow(x)
  cols <- seq_len(n_remaining_ranks)
  target_cols <- as.matrix(remaining_ranks)[, cols, drop = FALSE]
  # (row, column) pairs in column-major order, matching matrix() filling
  idx <- cbind(rep(seq_len(n_rows), times = n_remaining_ranks),
    as.vector(target_cols))
  matrix(x[idx], nrow = n_rows, ncol = n_remaining_ranks)
}

# For each respondent, the candidate indices still unranked after their top
# K - 1 choices.
#
# Args:
#   K: rank position (1 = first choice).
#   resp: respondent ids, matched against data$subject_id.
#   cand: candidate names; indices are positions in this vector.
#   data: data.table with subject_id and candidate columns, presumably in
#     rank order within each subject -- TODO confirm.
#   n_cand: total number of candidates (default 7, as before).
#
# Returns: matrix with one row per respondent (NULL when `resp` is empty,
#   where 1:length(resp) previously produced a spurious iteration).
make_cand_rank_K_to_7 <- function(K, resp, cand, data, n_cand = 7) {
  do.call(rbind, lapply(seq_along(resp), function(i) {
    ranking <- data[subject_id == resp[i], match(candidate, cand)]
    setdiff(seq_len(n_cand), head(ranking, K - 1))
  }))
}

# Position of each respondent's K-th ranked candidate among the candidates
# still available at rank K.
#
# For K == 1 this is the candidate's index in `cand`. For K > 1 it is the
# position within setdiff(1:n_cand, top K - 1 choices).
#
# Args: as for make_cand_rank_K_to_7(); n_cand defaults to 7.
#
# Returns: integer vector, one entry per respondent. It is NA where the K-th
#   choice is missing or not among the remaining candidates; previously that
#   case made sapply() silently return a list.
make_rankK <- function(K, resp, cand, data, n_cand = 7) {
  vapply(seq_along(resp), function(i) {
    ranking <- data[subject_id == resp[i], match(candidate, cand)]
    if (K == 1) {
      return(ranking[1])
    }
    remaining <- setdiff(seq_len(n_cand), head(ranking, K - 1))
    position <- which(remaining == ranking[K])
    if (length(position) == 0L) NA_integer_ else position
  }, integer(1))
}

# Draw one random ranking of K = length(probs) items by sequential sampling
# without replacement. Each position is filled by sampling among the
# not-yet-ranked items in proportion to their `probs`.
#
# Returns: integer permutation of seq_len(K); the first element is top rank.
#
# The previous loop `for (k in 2:(K-1))` ran backwards for K <= 2. It also
# called sample() on a single remaining value, which sample() treats as 1:x,
# so K = 1 and K = 2 errored or misbehaved. Draws for K >= 3 are unchanged.
random_ranking <- function(probs) {
  K <- length(probs)
  if (K < 2L) {
    return(seq_len(K))
  }
  ranks <- sample.int(K, 1L, prob = probs)
  # K - 2 further draws; the final item is appended without sampling
  for (k in seq_len(K - 2L)) {
    remaining <- setdiff(seq_len(K), ranks)
    pick <- sample.int(length(remaining), 1L, prob = probs[-ranks])
    ranks <- c(ranks, remaining[pick])
  }
  c(ranks, setdiff(seq_len(K), ranks))
}

# Draw one random ranking from logit (softmax) scores: items are picked
# sequentially without replacement, with probability proportional to
# exp(score) among those not yet ranked.
#
# Returns: integer permutation of seq_len(length(score)); first = top rank.
#
# Scores are shifted by their maximum before exp(), so large scores no longer
# overflow to Inf / Inf = NaN. K <= 2 is handled explicitly; the old
# `2:(K-1)` loop ran backwards there.
random_ranking_logit <- function(score) {
  K <- length(score)
  if (K < 2L) {
    return(seq_len(K))
  }
  weights <- exp(score - max(score))
  probs <- weights / sum(weights)
  ranks <- sample.int(K, 1L, prob = probs)
  # K - 2 further draws; the final item is appended without sampling
  for (k in seq_len(K - 2L)) {
    remaining <- setdiff(seq_len(K), ranks)
    pick <- sample.int(length(remaining), 1L, prob = probs[-ranks])
    ranks <- c(ranks, remaining[pick])
  }
  c(ranks, setdiff(seq_len(K), ranks))
}


# equivalence testing ----
# Weighted two-sample equivalence t-test (F-test form with noncentrality
# m * n * epsilon^2 / N), plus the inverted equivalence bound.
#
# Args:
#   x, y: outcomes for the two groups; NA entries are dropped together with
#     their weights.
#   w.x, w.y: weights aligned with x and y.
#   alpha: test level.
#   epsilon: equivalence range in standardised units.
#   std.err, cluster.x, cluster.y: accepted for interface compatibility but
#     not used by this implementation.
#
# Returns: numeric vector c(inverted bound (NA if root-finding fails),
#   p-value rounded to 4 dp, observed standardised difference, observed
#   weighted mean difference, power, SD of y).
#
# Depends on wtd.mean() and wtd.var() being in scope (presumably from Hmisc).
equiv.t.test <- function(x, y, w.x, w.y, alpha = .05, epsilon = .2,
  std.err = "nominal", cluster.x = NULL, cluster.y = NULL)
{
  # The NA masks must be taken before x and y are overwritten. Previously
  # the weights were filtered with !is.na() of the already-filtered data,
  # which kept every weight and misaligned them with the outcomes.
  keep.x <- !is.na(x)
  keep.y <- !is.na(y)
  x <- x[keep.x]
  y <- y[keep.y]
  weights.x <- w.x[keep.x]
  weights.y <- w.y[keep.y]
  weights.x <- weights.x / sum(weights.x)
  weights.y <- weights.y / sum(weights.y)
  dbar <- wtd.mean(x, weights.x) - wtd.mean(y, weights.y)
  m <- as.double(length(x))
  n <- as.double(length(y))
  N <- m + n
  df <- N - 2
  x.var <- wtd.var(x, weights = weights.x, normwt = TRUE)
  y.var <- wtd.var(y, weights = weights.y, normwt = TRUE)
  non.cent <- (m * n * epsilon^2) / N
  critical.const <- sqrt(qf(alpha, 1, df, non.cent))
  se <- sqrt((m - 1) * x.var + (n - 1) * y.var) / sqrt(m * n * df / N)
  t.stat <- dbar / se
  p <- pf(abs(t.stat) ^ 2, 1, df, non.cent)
  obs_smd <- (wtd.mean(x, weights = weights.x) - wtd.mean(y,
    weights = weights.y)) / sqrt(y.var)
  # Smallest equivalence range (in SD units) at which the observed statistic
  # would be significant; NA when uniroot() fails to bracket a root.
  inverted <- try(
    uniroot(function(x) {
      pf(
        abs(t.stat)^2,
        1,
        N-2,
        ncp = (m * n * x ^ 2) / N) - ifelse(
          pf(abs(t.stat) ^ 2, 1, N - 2, ncp = (m * n * 0 ^ 2) / N) < alpha,
          pf(abs(t.stat) ^ 2, 1, N - 2, ncp = (m * n * obs_smd ^ 2) / N),
          alpha)
    },
      c(0, 10 * abs(t.stat)),
      tol = 0.0001)$root,
    silent = TRUE
  )
  if (inherits(inverted, "try-error")) {
    inverted <- NA
  }

  c(inverted,
    round(p, 4),
    obs_smd,
    dbar,
    2 * pt(critical.const, df) - 1, # power
    sqrt(y.var)
  )
}

# Run equiv.t.test() for every column of X, comparing Tr == 1 against
# Tr == 0 rows.
#
# Args:
#   X: data.table of covariates (one test per column).
#   Tr: treatment indicator aligned with rows of X.
#   w: weights aligned with rows of X.
#   epsilon.method: how the common equivalence range is chosen:
#     "std.effect" (|mean(Y | Tr=1) - mean(Y | Tr=0)| / sd(Y | Tr=0)),
#     "custom" (custom.epsilon), "strict" (0.36) or "liberal" (0.74).
#   fdr_correct: if TRUE, p-values are Benjamini-Hochberg adjusted.
#   std.err, cluster.x, cluster.y: forwarded to equiv.t.test().
#   type: accepted but unused.
#
# Returns: data.table with one row per column of X: display.names,
#   inverted, inverted.scaled (inverted times the control-group SD), p.vals,
#   observed.diff, observed.smd and power.
run_equiv <- function(X, Tr, w, epsilon.method = "std.effect", Y = NULL,
  custom.epsilon = NULL, std.err = "nominal", type = "equiv.t.test",
  fdr_correct = FALSE, cluster.x = NULL, cluster.y = NULL)
{
  tol <- switch(epsilon.method,
    std.effect = abs(mean(Y[Tr == 1], na.rm = TRUE) - mean(Y[Tr == 0], na.rm = TRUE)) /
      sd(Y[Tr == 0], na.rm = TRUE),
    custom = {
      if (is.null(custom.epsilon))
        stop("ERROR: Must enter a custom epsilon value to use the 'custom' epsilon.method.")
      custom.epsilon
    },
    strict = 0.36,
    liberal = 0.74,
    stop("ERROR: 'epsilon.method' not set to a valid option")
  )

  ranges <- rep(tol, ncol(X))
  names(ranges) <- names(X)

  # One result row per covariate. The lambda argument keeps the name `var`
  # because it is evaluated inside data.table's j scope.
  per_covariate <- lapply(names(X), function(var) {
    equiv.t.test(
      X[Tr == 1, get(var)], X[Tr == 0, get(var)],
      w[Tr == 1], w[Tr == 0],
      epsilon = ranges[var],
      std.err = std.err, cluster.x = cluster.x, cluster.y = cluster.y)
  })
  tests <- do.call(rbind, per_covariate)

  # Extract result column k, labelled by covariate name.
  result_column <- function(k) {
    values <- unlist(tests[, k])
    names(values) <- names(ranges)
    values
  }
  inverted <- result_column(1)
  p.vals <- result_column(2)
  observed.smd <- result_column(3)
  observed.diff <- result_column(4)
  power <- result_column(5)
  control_sd <- result_column(6)

  inverted.scaled <- vapply(names(inverted),
    function(v) inverted[[v]] * control_sd[[v]], numeric(1))
  names(inverted.scaled) <- names(ranges)

  if (fdr_correct) {
    p.vals <- p.adjust(p.vals, method = "BH")
  }

  data.table(
    display.names = names(inverted),
    inverted = inverted,
    inverted.scaled = inverted.scaled, p.vals = p.vals,
    observed.diff = observed.diff, observed.smd = observed.smd, power = power)
}

# Draw an equivalence-test summary figure. Five panels sit side by side:
# variable names, observed mean difference, the main equivalence-range plot,
# the inverted equivalence CI and the p-value.
#
# Args:
#   equiv_tests: table of per-variable results with columns inverted,
#     inverted.scaled, observed.diff, observed.smd and p.vals (as returned
#     by run_equiv()), plus a `tol` column giving the equivalence range.
#     NOTE(review): run_equiv() does not add `tol`; presumably the caller
#     does -- confirm.
#   panel.widths: relative widths of the five panels, left to right.
#   display.names: optional row labels (top to bottom); defaults to
#     row.names(equiv_tests).
#   var.rounding, pval.rounding: decimal places for the numeric columns.
#   fdr_correct: only changes the p-value column title.
#   title_text: figure title.
#
# Returns: a cowplot drawing (a ggplot object).
#
# Changes from the previous version: leftover debug print() calls and a
# no-op `labs(font = 5)` are removed. `require(cowplot)` is replaced by an
# explicit dependency check with namespaced cowplot calls. The four
# text-only side panels share one helper.
generate_plot <- function(equiv_tests, panel.widths=c(1,
  1, 5, 1, 1),
  display.names = NULL, var.rounding = 1, pval.rounding = 2,
  fdr_correct = FALSE, title_text = "")
{
  if (!requireNamespace("cowplot", quietly = TRUE)) {
    stop("generate_plot() requires the 'cowplot' package.", call. = FALSE)
  }
  if (!is.null(display.names)) {
    equiv_tests$names <- factor(display.names, levels = rev(display.names))
  } else {
    equiv_tests$names <- factor(row.names(equiv_tests),
      levels = rev(row.names(equiv_tests)))
  }
  equiv_tests$const <- rep(1, nrow(equiv_tests))
  equiv_tests <- equiv_tests[rev(seq_len(nrow(equiv_tests))), ]

  # Main panel: inverted CI bars, the equivalence range and the observed SMD.
  tol_values <- unique(equiv_tests$tol)
  g <- ggplot(equiv_tests, aes(x = names)) +
    geom_linerange(aes(ymin = -1 * inverted, ymax = inverted), size = 5,
      color = "darkgray", alpha = 0.9)
  if (length(tol_values) > 1) {
    # variable-specific ranges drawn per row
    g <- g + geom_linerange(aes(ymin = -1 * tol, ymax = tol), size = 10, color = "gray")
  } else {
    # one shared range drawn as dashed reference lines
    g <- g + geom_hline(yintercept = c(-1 * tol_values, tol_values), linetype = 2, size = 0.75)
  }
  g <- g + scale_shape_identity() +
    geom_point(aes(y = observed.smd), color = "black", shape = 18, size = 4) +
    theme_bw() + coord_flip() +
    theme(
      axis.text.y = element_blank(),
      axis.ticks.y = element_blank(),
      axis.title.y = element_blank(),
      axis.text.x = element_text(size = 10),
      axis.title.x = element_text(size = 10),
      plot.title = element_text(size = 10, hjust = .5)) +
    labs(x = NULL, y = "Equivalence Range (in standard deviations \u03C3)")
  if (length(tol_values) == 1) {
    g <- g + ggtitle(paste0("Equivalence Tests  \n", title_text, "\n",
      "Equivalence Range: +/- ", round(tol_values, 2), "\u03C3 \n"))
  } else {
    g <- g + ggtitle(paste0("\n \n \nEquivalence Tests"))
  }

  # Narrow text-only column whose rows line up with the main panel.
  text_panel <- function(labels, title) {
    panel_data <- data.frame(names = equiv_tests$names, const = equiv_tests$const)
    panel_data$label <- labels
    ggplot(panel_data, aes(x = names, y = const, label = label)) +
      geom_text() +
      theme_bw() + coord_flip() +
      theme(panel.grid.minor = element_blank(),
        panel.grid.major = element_blank(),
        axis.text.x = element_text(color = "white", size = 10),
        axis.ticks.x = element_line(color = "white"),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank(),
        axis.title.x = element_text(size = 10),
        plot.title = element_text(size = 10)) +
      ylim(1-.05, 1.05) +
      ggtitle(title) +
      labs(y = " ", x = NULL)
  }

  num_fmt <- paste0("%0.0", var.rounding, "f")
  g_inv <- text_panel(
    sprintf(num_fmt, round(equiv_tests$inverted.scaled, var.rounding)),
    "Equivalence\nConfidence\nInterval (+/-)\n(Scale of Var)")
  g_obs <- text_panel(
    sprintf(num_fmt, round(equiv_tests$observed.diff, var.rounding)),
    "Observed\nMean\nDifference\n(Scale of Var)")
  g_pval <- text_panel(
    round(equiv_tests$p.vals, pval.rounding),
    if (fdr_correct) "\nFDR\nCorrected\nP-value" else "\n\n\nP-value")
  g_var <- text_panel(equiv_tests$names, "\n \n \nVariable") +
    theme(panel.border = element_blank(),
      plot.title = element_text(hjust = 0.5))

  # Left edges and widths of the five panels as fractions of the canvas.
  cs <- c(0, cumsum(panel.widths) / sum(panel.widths))
  x <- head(cs, -1)
  w <- diff(cs)
  cowplot::ggdraw() +
    cowplot::draw_plot(g_var, x = x[1], y = 0, height = .9, width = w[1]) +
    cowplot::draw_plot(g_obs, x = x[2], y = 0, height = .9, width = w[2]) +
    cowplot::draw_plot(g, x = x[3], y = 0, height = .9, width = w[3]) +
    cowplot::draw_plot(g_inv, x = x[4], y = 0, height = .9, width = w[4]) +
    cowplot::draw_plot(g_pval, x = x[5], y = 0, height = .9, width = w[5]) +
    cowplot::draw_plot_label(label = title_text, x = .5, y = .96, hjust = .5)
}
